"""
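Evaluate images generated from per-voxel captions against measured NSD responses.

For every selected voxel, the script embeds the previously generated image with
the evaluation model, computes the cosine similarity between that embedding and
the stimulus features, and correlates those similarities with the voxel's
responses. Example commands follow.

[Image]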

## LaVCa
python -m evaluation.eval_image_similarity \
    --subject_names subj01 subj02 subj05 subj07 \
    --atlasname streams floc-faces floc-places floc-bodies floc-words \
    --modality image \
    --modality_hparam default \
    --model_name CLIP-ViT-B-32 \
    --reduce_dims default 0 \
    --dataset_name OpenImages \
    --max_samples full \
    --dataset_path ./data/OpenImages/frames_518x518px \
    --dataset_captioner MiniCPM-Llama3-V-2_5 \
    --voxel_selection pvalues_corrected 0.05 \
    --layer_selection best \
    --caption_model MeaCap \
    --keywords_model gpt-4o-2024-08-06 \
    --correct_model default \
    --candidate_num 50 \
    --key_num 5 \
    --temperature 0.05 \
    --filter_th 0.15 \
    --cc_method spearman \
    --gen_model FLUX.1-schnell \
    --eval_model CLIP-ViT-B-32 \
    --device cuda

## BrainSCUBA
python -m evaluation.eval_image_similarity \
    --subject_names subj01 subj02 subj05 subj07 \
    --atlasname streams floc-faces floc-places floc-bodies floc-words \
    --modality image \
    --modality_hparam default \
    --model_name CLIP-ViT-B-32 \
    --reduce_dims default 0 \
    --dataset_name OpenImages \
    --max_samples full \
    --dataset_path ./data/OpenImages/frames_518x518px \
    --dataset_captioner MiniCPM-Llama3-V-2_5 \
    --voxel_selection pvalues_corrected 0.05 \
    --layer_selection best \
    --caption_model brainscuba \
    --keywords_model default \
    --correct_model default \
    --candidate_num -1 \
    --temperature -1 \
    --key_num -1 \
    --cc_method spearman \
    --gen_model FLUX.1-schnell \
    --eval_model CLIP-ViT-B-32 \
    --device cpu
"""

import torch
import argparse
import os
import json
from tqdm import tqdm
from utils.utils import (
    search_best_layer, TrnVal, gen_nulldistrib_gauss, fdr_correction, 
    make_filename, collect_fmri_byroi_for_nsd, create_volume_index_and_weight_map
)
import numpy as np
from PIL import Image
from utils.nsd_access import NSDAccess
import scipy
from himalaya.scoring import correlation_score
from transformers import AutoProcessor, CLIPVisionModelWithProjection
# nltk.download('punkt')

# Seed PyTorch's default random number generator as soon as the module is imported.
torch.manual_seed(42)

def load_resp_wholevoxels_for_nsd(subject_name, dataset="all", atlas="streams"):
    """Collect one subject's training and validation fMRI responses.

    Args:
        subject_name: Subject identifier forwarded to
            ``collect_fmri_byroi_for_nsd``.
        dataset: Accepted for call compatibility; this function does not use it.
        atlas: Atlas name forwarded as ``atlasname``.

    Returns:
        A ``TrnVal`` whose ``trn`` field holds the ``"TRAIN"`` responses and
        whose ``val`` field holds the ``"VALID"`` responses.
    """
    # The training split is collected first, then the validation split.
    responses = {
        split: collect_fmri_byroi_for_nsd(
            subject_name, trainvalid=split, atlasname=atlas
        )
        for split in ("TRAIN", "VALID")
    }
    return TrnVal(trn=responses["TRAIN"], val=responses["VALID"])

    
def _load_eval_model(eval_model, device):
    """Return ``(model, processor)`` for the requested evaluation model.

    Args:
        eval_model: Name of the evaluation model; only ``"CLIP-ViT-B-32"`` is
            handled.
        device: Device the model is moved to.

    Raises:
        NotImplementedError: If ``eval_model`` is not a supported name.
    """
    if eval_model != "CLIP-ViT-B-32":
        raise NotImplementedError(f"Model {eval_model} is not supported.")
    model_id = "openai/clip-vit-base-patch32"
    model = CLIPVisionModelWithProjection.from_pretrained(model_id).to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    return model, processor


def _caption_hparam_tag(args):
    """Return the caption-hyperparameter fragment shared by result file names."""
    return (
        f"{args.caption_model}_kmodel_{args.keywords_model}_{args.key_num}keys_"
        f"{args.temperature}temp_{args.filter_th}th_{args.candidate_num}cands_"
        f"cmodel_{args.correct_model}"
    )


def main(args):
    """Score generated images against voxel responses for every subject.

    For each subject and each voxel returned by
    ``create_volume_index_and_weight_map``, the function loads the image that
    was generated for the voxel, embeds it with the evaluation model, computes
    the cosine similarity between that embedding and the training/validation
    stimulus features, and correlates the similarities with the voxel's
    responses. Per-voxel correlations are saved next to the voxel's data, then
    all correlations and the values returned by ``fdr_correction`` are saved
    per subject.

    A per-voxel lock file (``temp_gen_image_<caption_model>.txt``) is used so
    that concurrently running jobs skip voxels another job is working on.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``subject_names``, ``atlasname``, ``modality``, ``modality_hparam``,
            ``model_name``, ``reduce_dims``, ``dataset_name``, ``max_samples``,
            ``voxel_selection``, ``layer_selection``, ``caption_model``,
            ``keywords_model``, ``correct_model``, ``candidate_num``,
            ``key_num``, ``temperature``, ``filter_th``, ``cc_method``,
            ``gen_model``, ``eval_model`` and ``device``.

    Raises:
        NotImplementedError: If ``args.eval_model`` is not supported.
        ValueError: If ``args.cc_method`` is neither ``"pearson"`` nor
            ``"spearman"``.
    """
    score_root_path = "./data/nsd/encoding"
    stim_root_path = "./data/stim_features/nsd"
    modality = args.modality
    modality_hparam = args.modality_hparam
    model_name = args.model_name
    file_type = args.voxel_selection[0]
    threshold = float(args.voxel_selection[1])

    # Fail fast on an unsupported evaluation model, before any data is touched.
    model, processor = _load_eval_model(args.eval_model, args.device)

    # NOTE(review): the original file called collect_stim_for_nsd without ever
    # importing it. It presumably lives in utils.utils next to
    # collect_fmri_byroi_for_nsd -- confirm and move this to the file's imports.
    from utils.utils import collect_stim_for_nsd

    nsda = NSDAccess('./data/NSD')
    sim_func = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

    # Size encoded in the generated images' file names.
    gen_img_width = 512
    gen_img_height = 512
    # Size the generated image is resized to before it is given to the processor.
    width = 224
    height = 224
    # Fix the seed value.
    torch.manual_seed(42)

    hparam_tag = _caption_hparam_tag(args)

    for subject_name in args.subject_names:
        print(subject_name)
        filename = make_filename(args.reduce_dims[0:2])

        print(f"Modality: {modality}, Modality hparams: {modality_hparam}, Feature: {model_name}, Filename: {filename}")
        # Load the selected layer for this subject.
        model_score_dir = f"{score_root_path}/{subject_name}/scores/{modality}/{modality_hparam}/{model_name}"
        if args.layer_selection == "best":
            target_best_cv_layer, _, _ = search_best_layer(model_score_dir, filename, select_topN="all")
        else:
            target_best_cv_layer = args.layer_selection
        print(f"Best layer: {target_best_cv_layer}")

        volume_index, weight_index_map, _target_top_voxels = create_volume_index_and_weight_map(
            subject_name=subject_name,
            file_type=file_type,
            threshold=threshold,
            model_score_dir=model_score_dir,
            target_best_cv_layer=target_best_cv_layer,
            filename=filename,
            nsda=nsda,
            atlasnames=args.atlasname  # args.atlasname is expected to be a list
        )

        stim_dir = f"{stim_root_path}/{modality}/{modality_hparam}/{model_name}/{target_best_cv_layer}"
        if args.reduce_dims[0] != "default":
            projector_stem = f"{stim_dir}/projector_{subject_name}_ave_{filename}"
            # NOTE(review): both loads unpickle data; only use trusted files.
            try:
                reducer_projector = np.load(f"{projector_stem}.npy", allow_pickle=True).item()
            except Exception:
                # Fall back to the pickle variant of the projector file.
                reducer_projector = np.load(f"{projector_stem}.pkl", allow_pickle=True)
        else:
            reducer_projector = None
        print(reducer_projector)

        resp = load_resp_wholevoxels_for_nsd(subject_name, "all", atlas="cortex")

        # Load the stimulus features and move them to the device once per
        # subject instead of once per voxel.
        stim = collect_stim_for_nsd(subject_name, args.modality, stim_dir)
        stim_trn = torch.tensor(stim.trn).to(args.device)
        stim_val = torch.tensor(stim.val).to(args.device)

        # Directory holding one "voxelXXXXX" folder per voxel plus the
        # subject-level result files.
        insilico_root = (
            f"./data/nsd/insilico/{subject_name}/{args.dataset_name}_{args.max_samples}"
            f"/{modality}/{modality_hparam}/{model_name}_{filename}/whole"
        )

        cc_dict_original = {"tr": [], "te": []}

        for voxel_index in tqdm(volume_index):
            print(f"voxel_index: {voxel_index}")
            vindex_pad = str(voxel_index).zfill(5)
            resp_save_path = f"{insilico_root}/voxel{vindex_pad}"
            temp_file_path = f"{resp_save_path}/temp_gen_image_{args.caption_model}.txt"

            # Create the lock atomically. This happens outside the try/finally
            # below so that a lock owned by another job is never deleted here.
            try:
                open(temp_file_path, "x").close()
            except FileExistsError:
                print(f"Simulation for {voxel_index} is being processed.")
                continue

            try:
                print(f"Now processing: {voxel_index}")
                if args.caption_model in ["MeaCap", "MeaCap_cc3m", "default"]:
                    cc_stem = f"{resp_save_path}/{args.cc_method}_cc_gen_img_sim_{hparam_tag}"
                else:
                    cc_stem = f"{resp_save_path}/{args.cc_method}_cc_gen_img_sim_{args.caption_model}"
                cc_tr_path = f"{cc_stem}_tr.npy"
                cc_te_path = f"{cc_stem}_te.npy"

                weight_index = weight_index_map[voxel_index]
                resp_trn = resp.trn[:, weight_index]
                resp_val = resp.val[:, weight_index]

                # The caption file's name determines the image/embedding names.
                if args.caption_model == "brainscuba":
                    source_path = os.path.join(resp_save_path, f"caption_{args.caption_model}_tau150.0.txt")
                    with open(source_path, "r") as f:
                        caption = f.read()
                    source_prefix, source_ext = "caption", ".txt"
                else:
                    source_path = os.path.join(resp_save_path, f"keys_and_text_{hparam_tag}.json")
                    with open(source_path, "r") as f:
                        caption = json.load(f)["text"]
                    source_prefix, source_ext = "keys_and_text", ".json"
                print(caption)

                source_name = os.path.basename(source_path)
                image_base_name = source_name.replace(
                    source_prefix, f"{args.gen_model}_{gen_img_height}x{gen_img_width}px"
                ).replace(source_ext, ".png")
                embs_base_name = source_name.replace(
                    source_prefix, f"{args.eval_model}_eval_{args.gen_model}"
                ).replace(source_ext, ".npy")

                image_save_dir = f"{resp_save_path}/gen_images"
                os.makedirs(image_save_dir, exist_ok=True)
                embs_save_path = f"{image_save_dir}/{embs_base_name}"

                if os.path.exists(cc_tr_path) and os.path.exists(cc_te_path):
                    # Reuse the correlations saved by an earlier run.
                    print(f"Already processed: {voxel_index}")
                    cc_dict_original["tr"].append(np.load(cc_tr_path))
                    cc_dict_original["te"].append(np.load(cc_te_path))
                    continue

                image_save_path = f"{image_save_dir}/{image_base_name}"

                # Load and resize the generated image; the context manager
                # closes the underlying file once the resized copy exists.
                with Image.open(image_save_path) as source_img:
                    gen_img = source_img.resize((width, height))

                # The processor is called with a placeholder text; only the
                # pixel values are used.
                dummy_text = "dummy"
                gen_img_tensor = processor(dummy_text, gen_img, return_tensors="pt")
                gen_img_tensor = gen_img_tensor["pixel_values"].to(args.device)

                # Inference only: no autograd graph is needed.
                with torch.no_grad():
                    embs = model(gen_img_tensor).image_embeds
                embs = embs.cpu().detach().numpy()
                print(embs.shape)
                np.save(embs_save_path, embs)

                # Similarity between the generated image and each stimulus.
                embs_tensor = torch.tensor(embs).to(args.device)
                cos_sim_tr = sim_func(embs_tensor, stim_trn).cpu().detach().numpy()
                cos_sim_te = sim_func(embs_tensor, stim_val).cpu().detach().numpy()

                if args.cc_method == "pearson":
                    cc_tr = correlation_score(resp_trn, cos_sim_tr)
                    cc_te = correlation_score(resp_val, cos_sim_te)
                elif args.cc_method == "spearman":
                    cc_tr = scipy.stats.spearmanr(resp_trn, cos_sim_tr).statistic
                    cc_te = scipy.stats.spearmanr(resp_val, cos_sim_te).statistic
                else:
                    raise ValueError(f"Unsupported cc_method: {args.cc_method}")
                print(cc_tr, cc_te)

                cc_dict_original["tr"].append(cc_tr)
                cc_dict_original["te"].append(cc_te)
                np.save(cc_tr_path, cc_tr)
                np.save(cc_te_path, cc_te)

            finally:
                # Release the lock this job created; cleanup stays best-effort.
                try:
                    os.remove(temp_file_path)
                except OSError as err:
                    print(f"Could not remove lock file {temp_file_path}: {err}")

        atlasname_savename = "_".join(args.atlasname)
        cc_prefix = f"{atlasname_savename}_{args.cc_method}_cc_gen_img_sim_"
        # NOTE(review): "MeaCap_cc3m" gets hyperparameter-tagged per-voxel files
        # above but the short subject-level name here -- confirm this is intended.
        if args.caption_model in ["MeaCap", "default"]:
            if args.filter_th:
                all_cc_savename = f"{cc_prefix}{hparam_tag}.npy"
            else:
                all_cc_savename = f"{cc_prefix}{args.caption_model}_kmodel_{args.keywords_model}_{args.key_num}keys_{args.candidate_num}cands_cmodel_{args.correct_model}.npy"
        else:
            all_cc_savename = f"{cc_prefix}{args.caption_model}.npy"
        all_voxels_cc_save_path = f"{insilico_root}/{all_cc_savename}"

        cc_dict_original["tr"] = np.array(cc_dict_original["tr"])
        cc_dict_original["te"] = np.array(cc_dict_original["te"])
        np.save(all_voxels_cc_save_path, cc_dict_original)

        # Corrected p-values for the training and validation correlations.
        pvalue_corrected_dict_original = {"tr": [], "te": []}
        for trnval in ["tr", "te"]:
            n_sample = resp.trn.shape[0] if trnval == "tr" else resp.val.shape[0]
            rccs = gen_nulldistrib_gauss(len(volume_index), n_sample)
            significant_voxels, pvalue_corrected = fdr_correction(cc_dict_original[trnval], rccs)
            print(f"Number of significant voxels: {len(significant_voxels)}")
            print(f"pvalue_corrected: {pvalue_corrected}")
            pvalue_corrected_dict_original[trnval] = pvalue_corrected
        np.save(
            all_voxels_cc_save_path.replace('cc_gen_img_sim_', 'cc_pvalues_corrected_gen_img_sim_'),
            pvalue_corrected_dict_original,
        )


if __name__ == "__main__":
    # Command-line interface, declared as (flag, add_argument keyword options)
    # pairs and registered in this order.
    argument_specs = [
        ("--subject_names", dict(nargs="*", type=str, required=True)),
        ("--atlasname", dict(nargs="*", type=str, required=True)),
        ("--modality", dict(type=str, required=True,
                            help="Name of the modality to use.")),
        ("--modality_hparam", dict(type=str, required=True)),
        ("--model_name", dict(type=str, required=True)),
        ("--reduce_dims", dict(nargs="*", type=str, required=True)),
        ("--dataset_name", dict(type=str, required=True)),
        ("--max_samples", dict(type=str, required=True)),
        ("--dataset_path", dict(type=str, required=True)),
        ("--dataset_captioner", dict(type=str, required=True)),
        ("--voxel_selection", dict(
            nargs="*", type=str, required=True,
            help="Selection method of voxels. Implemented type are 'uv' and 'share'.")),
        ("--layer_selection", dict(type=str, required=False, default="best")),
        ("--caption_model", dict(type=str, required=True,
                                 help="Name of the captioning model to use.")),
        ("--keywords_model", dict(type=str, required=False)),
        ("--correct_model", dict(
            type=str, required=True,
            choices=["None", "default", "gpt-4o-2024-08-06"],
            help="Name of the correction model to use.")),
        ("--candidate_num", dict(type=int, required=True)),
        ("--key_num", dict(type=int, required=True)),
        ("--temperature", dict(type=float, required=True)),
        ("--filter_th", dict(type=float, required=False)),
        ("--cc_method", dict(type=str, required=True,
                             choices=["spearman", "pearson"])),
        ("--gen_model", dict(type=str, required=True)),
        ("--eval_model", dict(type=str, required=True)),
        ("--device", dict(type=str, required=True, help="Device to use.")),
        ("--embs_only", dict(action="store_true", required=False, default=False)),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in argument_specs:
        parser.add_argument(flag, **options)

    # Parse the command line and run the evaluation.
    main(parser.parse_args())